{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "LaTeX-OCR training.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Train a LaTeX OCR model\n",
        "This brief notebook shows how to fine-tune or train a LaTeX OCR model.\n",
        "\n",
        "I've opted to mix handwritten data into the regular PDF LaTeX images. To do that, I started from the released pretrained model and continued training on the slightly larger combined corpus."
      ],
      "metadata": {
        "id": "YtR1GhYwnLnu"
      }
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r396ah-Q3EQc"
      },
      "source": [
        "!pip install pix2tex[train] -qq"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dZ4PLwkb3RIs"
      },
      "source": [
        "import os\n",
        "!mkdir -p LaTeX-OCR\n",
        "os.chdir('LaTeX-OCR')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cUsTlxXV3Mot"
      },
      "source": [
        "!pip install gpustat -q\n",
        "!pip install opencv-python-headless==4.1.2.30 -U -q\n",
        "!pip install --upgrade --no-cache-dir gdown -q"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# check what GPU we have\n",
        "!gpustat"
      ],
      "metadata": {
        "id": "uhLzh5vyaCaL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aAz37dDU21zu"
      },
      "source": [
        "!mkdir -p dataset/data\n",
        "!mkdir -p images\n",
        "# Google Drive IDs\n",
        "# handwritten: 13vjxGYrFCuYnwgDIUqkxsNGKk__D_sOM\n",
        "# pdf - images: 176PKaCUDWmTJdQwc-OfkO0y8t4gLsIvQ\n",
        "# pdf - math: 1QUjX6PFWPa-HBWdcY-7bA5TRVUnbyS1D\n",
        "!gdown -O dataset/data/crohme.zip --id 13vjxGYrFCuYnwgDIUqkxsNGKk__D_sOM\n",
        "!gdown -O dataset/data/pdf.zip --id 176PKaCUDWmTJdQwc-OfkO0y8t4gLsIvQ\n",
        "!gdown -O dataset/data/pdfmath.txt --id 1QUjX6PFWPa-HBWdcY-7bA5TRVUnbyS1D\n",
        "os.chdir('dataset/data')\n",
        "!unzip -q crohme.zip\n",
        "!unzip -q pdf.zip\n",
        "# split the handwritten data: move 1000 random images to the val set, the rest stay in the train set\n",
        "os.chdir('images')\n",
        "!mkdir -p ../valimages\n",
        "!ls | shuf -n 1000 | xargs -I {} mv {} ../valimages\n",
        "os.chdir('../../..')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now we generate the datasets. Multiple datasets can be strung together into one large lookup table. The pkl files store only the image sizes, the image locations and the ground-truth LaTeX code. Because the sizes are stored, we can serve batches of images that all share the same dimensions."
      ],
      "metadata": {
        "id": "2BMuIqRIqG-8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python -m pix2tex.dataset.dataset -i dataset/data/images dataset/data/train -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt -o dataset/data/train.pkl"
      ],
      "metadata": {
        "id": "1JebcEarl-g6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!python -m pix2tex.dataset.dataset -i dataset/data/valimages dataset/data/val -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt -o dataset/data/val.pkl"
      ],
      "metadata": {
        "id": "x_Orutb37xHD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# download the weights we want to fine-tune\n",
        "!curl -L -o weights.pth https://github.com/lukas-blecher/LaTeX-OCR/releases/download/v0.0.1/weights.pth"
      ],
      "metadata": {
        "id": "I3iOyEEBbw58"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# If using wandb (Weights & Biases)\n",
        "!pip install -q wandb\n",
        "# You can skip the login below if you don't want to use W&B or don't have a W&B account.\n",
        "#!wandb login"
      ],
      "metadata": {
        "id": "vow2NnpHmWt0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# generate a Colab-specific config (leave 'debug' set to true if wandb is not used)\n",
        "!echo {backbone_layers: [2, 3, 7], betas: [0.9, 0.999], batchsize: 10, bos_token: 1, channels: 1, data: dataset/data/train.pkl, debug: true, decoder_args: {'attn_on_attn': true, 'cross_attend': true, 'ff_glu': true, 'rel_pos_bias': false, 'use_scalenorm': false}, dim: 256, encoder_depth: 4, eos_token: 2, epochs: 50, gamma: 0.9995, heads: 8, id: null, load_chkpt: 'weights.pth', lr: 0.001, lr_step: 30, max_height: 192, max_seq_len: 512, max_width: 672, min_height: 32, min_width: 32, model_path: checkpoints, name: mixed, num_layers: 4, num_tokens: 8000, optimizer: Adam, output_path: outputs, pad: false, pad_token: 0, patch_size: 16, sample_freq: 2000, save_freq: 1, scheduler: StepLR, seed: 42, temperature: 0.2, test_samples: 5, testbatchsize: 20, tokenizer: dataset/tokenizer.json, valbatches: 100, valdata: dataset/data/val.pkl} > colab.yaml"
      ],
      "metadata": {
        "id": "OnsNCLp84QSY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c8NU5j2k3z36"
      },
      "source": [
        "!python -m pix2tex.train --config colab.yaml"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "g3DU9KxubWgq"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}